--- title: "[old portfolio] Pytorch VAE_GAN" date: 2020-01-01 00:00:00 +0900 categories: jekyll update ---
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision
import copy
import glob
# Report CUDA availability and select the compute device.
print("{0}ly cuda is available".format(torch.cuda.is_available()))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The original queried GPUs 0 and 1 unconditionally, which raises a
# RuntimeError on machines with fewer than two GPUs; query only the
# devices that actually exist.
if torch.cuda.is_available():
    gpu_names = tuple(torch.cuda.get_device_name(i)
                      for i in range(torch.cuda.device_count()))
# MNIST pipeline: PIL images -> float tensors in [0, 1]; no augmentation.
transform = transforms.Compose([
    transforms.ToTensor(),
    #transforms.Lambda(lambda x: x.to(device))
])

# 60k training and 10k test images, served in batches of 1000.
Trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
Testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
TrainLoad = torch.utils.data.DataLoader(Trainset, batch_size=1000)
TestLoad = torch.utils.data.DataLoader(Testset, batch_size=1000)
# 60k training, 10k testing datapoints in total#
# Peek at one test batch and display the first five digits with labels.
testiter = iter(TestLoad)
# `.next()` on iterators was removed in Python 3 (and DataLoader iterators
# dropped it in newer PyTorch); the builtin `next()` is the portable form.
testim, testlb = next(testiter)
fig, ax = plt.subplots(1, 5, figsize=(25, 4))
for i in range(5):
    ax[i].imshow(testim[i, 0, :, :].cpu())
    ax[i].set_title('{}'.format(testlb[i].numpy()), fontsize=35)
    ax[i].axis('off')
testim[:].shape  # 1000 per batch, 1 channel (greyscale), 28 x 28 pixels
class Classifier(nn.Module):
    """Small CNN digit classifier: two strided conv layers + two FC layers."""

    def __init__(self):
        super().__init__()
        # Attribute names are load-bearing: conv1/conv2 state dicts are
        # copied into the VAE encoders further down the file.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=(2, 2))
        self.fc1 = nn.Linear(64 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) class logits."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        flat = features.view(-1, 64 * 6 * 6)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)
def Cacc():  # test set accuracy
    """Return the fraction of test images Cnet classifies correctly.

    Reads the module-level Cnet, TestLoad and device; runs without
    gradient tracking.
    """
    hits, seen = 0, 0
    with torch.no_grad():
        for images, labels in TestLoad:
            logits = Cnet(images.to(device))
            guesses = torch.argmax(logits, 1)
            seen += labels.size(0)
            hits += (guesses == labels.to(device)).sum().item()
    return hits / seen
# Train the classifier for 20 epochs with SGD, tracking per-batch loss
# and per-epoch test accuracy.
Cnet = Classifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(Cnet.parameters(), lr=0.01, momentum=0.9)
ls = []           # per-batch training losses
accs = [Cacc()]   # test accuracy before training, then after each epoch
for epoch in range(20):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(TrainLoad):
        optimizer.zero_grad()
        outputs = Cnet(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        ls.append(loss.to('cpu').detach().numpy())
        running_loss += loss.item()
        if i % 60 == 59:  # once per epoch (60 batches of 1000)
            print('[epoch %d] loss: %.3f' %
                  (epoch + 1, running_loss / 60))
            running_loss = 0.0
    accs.append(Cacc())
print('Finished Training')
# Dual-axis figure: per-batch loss (left axis, blue) against test accuracy
# (right axis, red), with the accuracy curve rescaled to the batch axis.
matplotlib.rcParams.update({'lines.linewidth': 3, 'axes.linewidth': 3,
                            'xtick.major.width': 3, 'xtick.major.size': 10,
                            'ytick.major.width': 3, 'ytick.major.size': 10})
ax = plt.subplots(figsize=(12, 8))[1]
ax.plot(ls, color='blue')
ax.set_xlabel('batches', fontsize=30)
ax.set_ylabel('loss', fontsize=30, color='blue');
ax.tick_params(labelsize=20)
ax.tick_params(axis='y', labelcolor='blue')
acc_ax = ax.twinx()
acc_ax.plot(np.arange(len(accs)) * (len(ls) / (len(accs) - 1)), accs, color='red')
acc_ax.set_ylabel('accuracy', fontsize=30, color='red')
acc_ax.tick_params(axis='y', labelsize=20, labelcolor='red')
acc_ax.set_ylim(0, 1);
# Persist / restore the trained classifier.  The original assigned PATH
# twice (redundant); `map_location` makes a GPU-saved checkpoint loadable
# on CPU-only machines as well.
PATH = './MnistCnet.pth'
#torch.save(Cnet.state_dict(), PATH)
Cnet.load_state_dict(torch.load(PATH, map_location=device))
class Encoder(nn.Module):
    """Conv encoder emitting a concatenated (mean | log-variance) vector.

    The output has 2 * latent_dim features: the first half is the posterior
    mean, the second half the log-variance (split downstream by Sampler).
    """

    def __init__(self, latent_dim=50):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=(2, 2))
        self.fc1 = nn.Linear(64 * 6 * 6, 2 * latent_dim)

    def forward(self, x):
        """Encode a (N, 1, 28, 28) batch into (N, 2 * latent_dim)."""
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        return self.fc1(h.view(-1, 64 * 6 * 6))
class Sampler(nn.Module):
    """Reparameterization trick: z = mean + eps * exp(logvar / 2).

    Expects input shaped (N, 2k) laid out as [mean | logvar]; returns
    (N, k) samples.  Reads the module-level `device` global.
    """

    def __init__(self):
        super().__init__()
        # Placeholder; overwritten with fresh noise on every forward pass.
        self.epsilon = torch.tensor([0]).to(device)

    def forward(self, x):
        half = int(x.size()[1] / 2)
        mean, logvar = x[:, :half], x[:, half:]
        self.epsilon = torch.randn((x.size()[0], half)).to(device)
        return mean + self.epsilon * torch.exp(logvar * 0.5)
# Build the 50-D VAE encoder/sampler and warm-start the encoder's conv
# stack with the classifier's trained filters.
Enet = Encoder(latent_dim=50).to(device)
Snet = Sampler().to(device)
Enet.conv1.load_state_dict(Cnet.conv1.state_dict())
Enet.conv2.load_state_dict(Cnet.conv2.state_dict())
#uncomment to fix the convolutional network part
#Enet.conv1.requires_grad = False
#Enet.conv2.requires_grad = False
#Enet.conv1.requires_grad|Enet.conv2.requires_grad
# NOTE: keep Cnet alive — its conv weights are reused when the 2-D latent
# model (Enet2d) is initialised further down; the original `del Cnet` here
# caused a NameError at that point.
del accs, ls
class Decoder(nn.Module):
    """Mirror of Encoder: a latent vector becomes a (N, 1, 28, 28) image."""

    def __init__(self, latent_dim=50):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, 64 * 6 * 6)
        self.deconv1 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=(2, 2))
        # output_padding restores the odd 28-pixel size lost to striding.
        self.deconv2 = nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=3, stride=(2, 2), output_padding=(1, 1))

    def forward(self, x):
        """Decode (N, latent_dim) codes into (N, 1, 28, 28) images."""
        h = F.relu(self.fc1(x)).view(-1, 64, 6, 6)
        h = F.relu(self.deconv1(h))
        return self.deconv2(h)
# Instantiate the 50-D decoder on the training device.
Dnet=Decoder(latent_dim=50).to(device)
def logpdf(x, mean, logvar):
    """Log-density of x under a diagonal Gaussian, summed over features.

    x, mean and logvar broadcast together; x is (batch, features) and the
    result is one log-probability per row (einsum 'ij->i' sums each row).
    """
    # log(2*pi) as a plain Python float: the original allocated a fresh
    # tensor on every call and moved it to the global `device`, which both
    # wastes allocations and breaks on CPU-only inputs.  A float scalar
    # broadcasts identically and keeps the function device-agnostic.
    log2pi = float(np.log(2.0 * np.pi))
    return torch.einsum('ij->i', -0.5 * ((x - mean) ** 2 / torch.exp(logvar) + logvar + log2pi))
def KLD(x, z):
    """Per-sample Monte-Carlo KL estimate: log q(x|z) - log p(x).

    z carries the encoder output laid out as [mean | logvar]; the prior
    is a standard normal (zero mean, zero log-variance).
    """
    half = int(z.size()[1] / 2)
    mean, logvar = z[:, :half], z[:, half:]
    zero = torch.tensor(0.0)
    return logpdf(x, mean, logvar) - logpdf(x, zero, zero)
# Stage 1: train as a plain autoencoder (KL weight B = 0); the KL term is
# still computed so its magnitude can be monitored.
criterion = nn.MSELoss()
optimizer = optim.Adam(list(Dnet.parameters()) + list(Enet.parameters()), lr=0.001)
display = 10   # print once every `display` epochs
A, B = 1, 0    # reconstruction / KL weights
for epoch in range(50):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(TrainLoad):
        optimizer.zero_grad()
        inputs = inputs.to(device)
        code = Enet(inputs)
        sample = Snet(code)
        gen_im = Dnet(sample)
        Closs = criterion(gen_im, inputs)
        Kloss = torch.mean(KLD(sample, code))
        loss = A * Closs + B * Kloss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 60 == 59 and epoch % display == display - 1:
            print('[epoch %d] running loss: %.4f reconstruction loss: %.4f KLD loss: %.4f' %
                  (epoch + 1, running_loss / 60 / display, A * Closs, B * Kloss))
            running_loss = 0.0
print('Finished Training')
# Side-by-side figure: two test digits (top row) and their VAE
# reconstructions (bottom row), with one caption banner per row.
fig = plt.figure(figsize=(12, 12))
gs = gridspec.GridSpec(2, 2, fig)
panels = [fig.add_subplot(gs[r, c]) for r in range(2) for c in range(2)]
panels[0].imshow(testim[0, 0, :, :])
panels[1].imshow(testim[1, 0, :, :])
panels[2].imshow(Dnet(Snet(Enet(testim.to(device)))).to('cpu')[0, 0, :, :].detach())
panels[3].imshow(Dnet(Snet(Enet(testim.to(device)))).to('cpu')[1, 0, :, :].detach())
for row, caption in enumerate(("Test Image", "Reconstructed Image")):
    banner = fig.add_subplot(gs[row, :])
    banner.set_title(caption, fontsize=72, pad=50, va='center')
    banner.axis('off')
fig.tight_layout()
# Interpolate between the codes of two test digits and decode each step.
fig, ax = plt.subplots(1, 10, figsize=(30, 3.5))
a = Enet(testim.to(device))[0, :]
b = Enet(testim.to(device))[1, :]
for step in range(10):
    w = step / 10
    mix = (a * (1 - w) + b * w).unsqueeze(0)
    ax[step].imshow(Dnet(Snet(mix)).to('cpu')[0, 0, :, :].detach())
    ax[step].axis('off')
banner = fig.add_subplot(ax[0].get_gridspec()[:])
banner.set_title('A series of reconstructions of the weighted sums in the latent space', fontsize=40, va='center', pad=20)
banner.axis('off')
fig.tight_layout()
# Decode 16 latent samples: a zero (mean | logvar) input makes Snet emit
# unit-Gaussian noise, so these are decoded prior samples.
IM = Dnet(Snet(torch.zeros(16, 100).to(device)))
fig, ax = plt.subplots(4, 4, figsize=(10, 10))
for k in range(16):
    row, col = divmod(k, 4)
    ax[row][col].imshow(IM.to('cpu')[k, 0, :, :].detach())
    ax[row][col].axis('off')
banner = fig.add_subplot(ax[0][0].get_gridspec()[:])
banner.set_title('Random latent space sampling', fontsize=40, va='center', pad=20)
banner.axis('off')
fig.tight_layout()
del IM
# Checkpoint the stage-1 (pure autoencoder) weights.  `map_location`
# makes GPU-saved checkpoints loadable on CPU-only machines.
#torch.save(Enet.state_dict(), './MnistEnet.pth')
#torch.save(Dnet.state_dict(), './MnistDnet.pth')
Enet.load_state_dict(torch.load('./MnistEnet.pth', map_location=device))
Dnet.load_state_dict(torch.load('./MnistDnet.pth', map_location=device))
# Stage 2: fine-tune with the KL term switched on (true VAE objective).
criterion = nn.MSELoss()
optimizer = optim.SGD(list(Dnet.parameters()) + list(Enet.parameters()), lr=0.01, momentum=0.9)
display = 10       # print once every `display` epochs
A, B = 1, 4.e-4    # reconstruction / KL weights
running_loss = 0.0
for epoch in range(300):
    for i, (inputs, labels) in enumerate(TrainLoad):
        optimizer.zero_grad()
        inputs = inputs.to(device)
        code = Enet(inputs)
        sample = Snet(code)
        gen_im = Dnet(sample)
        Closs = criterion(gen_im, inputs)
        Kloss = torch.mean(KLD(sample, code))
        loss = A * Closs + B * Kloss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 60 == 59 and epoch % display == display - 1:
            print('[epoch %d] running loss: %.4f reconstruction loss: %.4f KLD loss: %.4f' %
                  (epoch + 1, running_loss / 60 / display, A * Closs, B * Kloss))
            running_loss = 0.0
print('Finished Training')
# Checkpoint the stage-2 (VAE) weights; `map_location` keeps the load
# working on CPU-only machines.
#torch.save(Enet.state_dict(), './MnistVEnet.pth')
#torch.save(Dnet.state_dict(), './MnistVDnet.pth')
Enet.load_state_dict(torch.load('./MnistVEnet.pth', map_location=device))
Dnet.load_state_dict(torch.load('./MnistVDnet.pth', map_location=device))
# Same comparison figure as before, now for the VAE-trained weights:
# two test digits on top, their reconstructions below.
fig = plt.figure(figsize=(12, 12))
gs = gridspec.GridSpec(2, 2, fig)
cells = [fig.add_subplot(gs[r, c]) for r in range(2) for c in range(2)]
cells[0].imshow(testim[0, 0, :, :])
cells[1].imshow(testim[1, 0, :, :])
cells[2].imshow(Dnet(Snet(Enet(testim.to(device)))).to('cpu')[0, 0, :, :].detach())
cells[3].imshow(Dnet(Snet(Enet(testim.to(device)))).to('cpu')[1, 0, :, :].detach())
for row, caption in enumerate(("Test Image", "Reconstructed Image")):
    label_ax = fig.add_subplot(gs[row, :])
    label_ax.set_title(caption, fontsize=72, pad=50, va='center')
    label_ax.axis('off')
fig.tight_layout()
# Latent interpolation between two test digits, VAE-trained weights.
fig, ax = plt.subplots(1, 10, figsize=(30, 3.5))
a = Enet(testim.to(device))[0, :]
b = Enet(testim.to(device))[1, :]
for step in range(10):
    blend = (a * (1 - step / 10) + b * (step / 10)).unsqueeze(0)
    ax[step].imshow(Dnet(Snet(blend)).to('cpu')[0, 0, :, :].detach())
    ax[step].axis('off')
caption_ax = fig.add_subplot(ax[0].get_gridspec()[:])
caption_ax.set_title('A series of reconstructions of the weighted sums in the latent space', fontsize=40, va='center', pad=20)
caption_ax.axis('off')
fig.tight_layout()
# Decoded prior samples with the VAE-trained weights (zero input to Snet
# yields unit-Gaussian latent samples).
IM = Dnet(Snet(torch.zeros(16, 100).to(device)))
fig, ax = plt.subplots(4, 4, figsize=(10, 10))
for k in range(16):
    row, col = divmod(k, 4)
    ax[row][col].imshow(IM.to('cpu')[k, 0, :, :].detach())
    ax[row][col].axis('off')
caption_ax = fig.add_subplot(ax[0][0].get_gridspec()[:])
caption_ax.set_title('Random latent space sampling', fontsize=40, va='center', pad=20)
caption_ax.axis('off')
fig.tight_layout()
del IM
# 2-D latent VAE, small enough to visualise the latent space directly.
Enet2d = Encoder(latent_dim=2).to(device)
Dnet2d = Decoder(latent_dim=2).to(device)
# The original reused `Cnet` here, but Cnet was deleted right after the
# 50-D model was warm-started, so this raised a NameError.  Reload the
# trained classifier from its checkpoint instead, copy the conv filters,
# and drop it again.
Cnet = Classifier().to(device)
Cnet.load_state_dict(torch.load('./MnistCnet.pth', map_location=device))
Enet2d.conv1.load_state_dict(Cnet.conv1.state_dict())
Enet2d.conv2.load_state_dict(Cnet.conv2.state_dict())
del Cnet
#uncomment to fix the convolutional network part
#Enet2d.conv1.requires_grad = False
#Enet2d.conv2.requires_grad = False
# Train the 2-D model, snapshotting the test batch's encoder outputs each
# epoch so the evolution of the latent space can be animated afterwards.
display = 50        # print once every `display` epochs
Ac, Bc = 1., 1.e-3  # reconstruction / KL weights
running_loss = 0.0
epochs = 500
A = []              # per-epoch (mean | logvar) codes of the test batch
criterion = nn.MSELoss()
optimizer = optim.SGD(list(Dnet2d.parameters()) + list(Enet2d.parameters()), lr=0.01, momentum=0.9)
for epoch in range(epochs):
    for i, (inputs, labels) in enumerate(TrainLoad):
        optimizer.zero_grad()
        inputs = inputs.to(device)
        code = Enet2d(inputs)
        sample = Snet(code)
        gen_im = Dnet2d(sample)
        Closs = criterion(gen_im, inputs)
        Kloss = torch.mean(KLD(sample, code))
        loss = Ac * Closs + Bc * Kloss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 60 == 59 and epoch % display == display - 1:
            print('[epoch %d] running loss: %.4f reconstruction loss: %.4f KLD loss: %.4f' %
                  (epoch + 1, running_loss / 60 / display, Ac * Closs, Bc * Kloss))
            running_loss = 0.0
    A.append(Enet2d(testim.to(device)).to('cpu').detach())
print('Finished Training')
# Checkpoint the 2-D latent model; `map_location` keeps the load working
# on CPU-only machines.
#torch.save(Enet2d.state_dict(), './MnistVEnet2d.pth')
#torch.save(Dnet2d.state_dict(), './MnistVDnet2d.pth')
Enet2d.load_state_dict(torch.load('./MnistVEnet2d.pth', map_location=device))
Dnet2d.load_state_dict(torch.load('./MnistVDnet2d.pth', map_location=device))
import matplotlib.animation as animation

# Animate how the 2-D codes of the test batch spread out over training.
fig = plt.figure(figsize=(12, 12))
scat = plt.scatter(A[0][:, 0], A[0][:, 1])
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plt.title('Epoch : 000')
lblist = testlb.detach().numpy()

def animate(n):
    """Draw frame n: one digit-shaped marker per test image."""
    fig.clear()
    if n % 20 == 19:
        print('Frame {0:03d}'.format(n + 1))
    # The original indexed A[10 * n] blindly; with 100 frames that reaches
    # index 990 and raises IndexError whenever len(A) < 991 (e.g. one
    # snapshot per epoch over 500 epochs).  Clamp to the last snapshot.
    snapshot = A[min(10 * n, len(A) - 1)]
    for i, point in enumerate(snapshot):
        lb = lblist[i]
        # point = [mean_x, mean_y, logvar_x, logvar_y]; marker size scales
        # with the magnitude of the log-variance pair.
        plt.scatter(point[0], point[1], s=100 * np.linalg.norm(point[2:4]),
                    marker='${0}$'.format(lb), c=[list(plt.get_cmap('hsv')(0.1 * lb))])
    plt.xlim(-3, 3)
    plt.ylim(-3, 3)
    plt.title('Epoch : {0:03d}'.format((n + 1) * 10), fontsize=10)

ani = animation.FuncAnimation(fig, animate, 100)
# NOTE(review): assumes ./figure/ already exists and Pillow is installed.
ani.save("./figure/2dtrain_v1.gif", writer='pillow')